{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 试一试 gym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score:  25.0\n"
     ]
    }
   ],
   "source": [
    "# try_gym.py\n",
    "# https://geektutu.com\n",
    "import gym  # 0.12.5\n",
    "import random\n",
    "import time\n",
    "\n",
    "env = gym.make(\"CartPole-v0\")  # 加载游戏环境\n",
    "\n",
    "state = env.reset()\n",
    "score = 0\n",
    "while True:\n",
    "    time.sleep(0.1)\n",
    "    env.render()   # 显示画面\n",
    "    action = random.randint(0, 1)  # 随机选择一个动作 0 或 1\n",
    "    state, reward, done, _ = env.step(action)  # 执行这个动作\n",
    "    score += reward     # 每回合的得分\n",
    "    if done:       # 游戏结束\n",
    "        print('score: ', score)  # 打印分数\n",
    "        break\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense (Dense)                (None, 64)                320       \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 20)                1300      \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 2)                 42        \n",
      "=================================================================\n",
      "Total params: 1,662\n",
      "Trainable params: 1,662\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# train.py\n",
    "# https://geektutu.com\n",
    "import random\n",
    "import gym\n",
    "import numpy as np\n",
    "from tensorflow.keras import models, layers\n",
    "\n",
    "env = gym.make(\"CartPole-v0\")  # 加载游戏环境\n",
    "\n",
    "STATE_DIM, ACTION_DIM = 4, 2  # State 维度 4, Action 维度 2\n",
    "model = models.Sequential([\n",
    "    layers.Dense(64, input_dim=STATE_DIM, activation='relu'),\n",
    "    layers.Dense(20, activation='relu'),\n",
    "    layers.Dense(ACTION_DIM, activation='linear')\n",
    "])\n",
    "model.summary()  # 打印神经网络信息"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data_one_episode():\n",
    "    '''生成单次游戏的训练数据'''\n",
    "    x, y, score = [], [], 0\n",
    "    state = env.reset()\n",
    "    while True:\n",
    "        action = random.randrange(0, 2)\n",
    "        x.append(state)\n",
    "        y.append([1, 0] if action == 0 else [0, 1]) # 记录数据\n",
    "        state, reward, done, _ = env.step(action) # 执行动作\n",
    "        score += reward\n",
    "        if done:\n",
    "            break\n",
    "    return x, y, score\n",
    "\n",
    "\n",
    "def generate_training_data(expected_score=100):\n",
    "    '''# 生成N次游戏的训练数据，并进行筛选，选择 > 100 的数据作为训练集'''\n",
    "    data_X, data_Y, scores = [], [], []\n",
    "    for i in range(10000):\n",
    "        x, y, score = generate_data_one_episode()\n",
    "        if score > expected_score:\n",
    "            data_X += x\n",
    "            data_Y += y\n",
    "            scores.append(score)\n",
    "    print('dataset size: {}, max score: {}'.format(len(data_X), max(scores)))\n",
    "    return np.array(data_X), np.array(data_Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练并保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset size: 749, max score: 120.0\n",
      "Train on 749 samples\n",
      "Epoch 1/5\n",
      "749/749 [==============================] - 0s 210us/sample - loss: 0.3735\n",
      "Epoch 2/5\n",
      "749/749 [==============================] - 0s 37us/sample - loss: 0.2777\n",
      "Epoch 3/5\n",
      "749/749 [==============================] - 0s 37us/sample - loss: 0.2523\n",
      "Epoch 4/5\n",
      "749/749 [==============================] - 0s 40us/sample - loss: 0.2403\n",
      "Epoch 5/5\n",
      "749/749 [==============================] - 0s 36us/sample - loss: 0.2355\n"
     ]
    }
   ],
   "source": [
    "data_X, data_Y = generate_training_data()\n",
    "model.compile(loss='mse', optimizer='adam')\n",
    "model.fit(data_X, data_Y, epochs=5)\n",
    "model.save('CartPole-v0-nn.h5')  # 保存模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试/预测模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using nn, score:  196.0\n",
      "using nn, score:  200.0\n",
      "using nn, score:  200.0\n",
      "using nn, score:  200.0\n",
      "using nn, score:  200.0\n"
     ]
    }
   ],
   "source": [
    "# predict.py\n",
    "# https://geektutu.com\n",
    "import time\n",
    "import numpy as np\n",
    "import gym\n",
    "from tensorflow.keras import models\n",
    "\n",
    "\n",
    "saved_model = models.load_model('CartPole-v0-nn.h5')  # 加载模型\n",
    "env = gym.make(\"CartPole-v0\")  # 加载游戏环境\n",
    "\n",
    "for i in range(5):\n",
    "    state = env.reset()\n",
    "    score = 0\n",
    "    while True:\n",
    "        time.sleep(0.01)\n",
    "        env.render()   # 显示画面\n",
    "        action = np.argmax(saved_model.predict(np.array([state]))[0])  # 预测动作\n",
    "        state, reward, done, _ = env.step(action)  # 执行这个动作\n",
    "        score += reward     # 每回合的得分\n",
    "        if done:       # 游戏结束\n",
    "            print('using nn, score: ', score)  # 打印分数\n",
    "            break\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
